from textgrad import logger
from textgrad.variable import Variable
from textgrad.engine import EngineLM
from .function import Function, BackwardContext
from typing import Callable, Dict, List

CONVERSATION_TEMPLATE_STRING = (
    "Function purpose: {function_purpose}\n\n"
    "<INPUTS_TO_FUNCTION> {inputs_string} </INPUTS_TO_FUNCTION>\n\n"
    "<OUTPUT_OF_FUNCTION> {response_value} </OUTPUT_OF_FUNCTION>\n\n"
)

# Has the gradient on the output.
CONVERSATION_START_INSTRUCTION_STRING_FN_CHAIN = (
    "You will give feedback to a variable with the following role: <ROLE> {variable_desc} </ROLE>. "
    "Here is an evaluation of a string-based function with inputs and outputs :\n\n"
    "{conversation}"
)

# Does not have gradient on the output
CONVERSATION_START_INSTRUCTION_STRING_FN_BASE = (
    "You will give feedback to a variable with the following role: <ROLE> {variable_desc} </ROLE>. "
    "Here is an evaluation of the variable using a string-based function:\n\n"
    "{conversation}"
)

OBJECTIVE_INSTRUCTION_CHAIN = (
    "This conversation is part of a larger system. The <OUTPUT_OF_FUNCTION> was later used as {response_desc}.\n\n"
    "<OBJECTIVE_FUNCTION>Your goal is to give feedback to the variable to address the following feedback on the OUTPUT_OF_FUNCTION: {response_gradient} </OBJECTIVE_FUNCTION>\n\n"
)

OBJECTIVE_INSTRUCTION_BASE = (
    "<OBJECTIVE_FUNCTION>Your goal is to give feedback and criticism to the variable given the above evaluation output. "
    "Our only goal is to improve the above metric, and nothing else. </OBJECTIVE_FUNCTION>\n\n"
)

# Some instructions for the backward pass are shared with LLMs
from .llm_backward_prompts import (
    EVALUATE_VARIABLE_INSTRUCTION,
    BACKWARD_SYSTEM_PROMPT
)

class StringBasedFunction(Function):
    def __init__(self, fn: Callable, function_purpose: str):
        """
        Autograd function for string-based functions.

        :param fn: The function to execute for the forward pass.
        :type fn: Callable
        :param function_purpose: The description of the purpose of the function. Analogous to role description for variables.
        :type function_purpose: str
        """
        super().__init__()
        self.fn = fn
        self.function_purpose = function_purpose

    def forward(self, 
                inputs: Dict[str, Variable], 
                response_role_description: str = None) -> Variable:
        """
        The forward mode for string-based functions

        :param inputs: The arguments that will be passed to the string based function. The keys are the names of the arguments.
        :type fn: Dict[str, Variable]
        :param response_role_description: The role description of the output variable. 
        :type response_role_description: str
        """        
        if response_role_description is None:
            response_role_description = f"Output of the string-based function with purpose: {self.function_purpose}"
        response_string = self.fn(**inputs)

        # Create the response variable
        response = Variable(
            value=response_string,
            predecessors=list(inputs.values()),
            role_description=response_role_description
        )
        
        logger.info(f"StringBasedFunction", extra={"text": f"In: {inputs}, Out: {response_string}"})
        
        # Populate the gradient function, using a container to store the backward function and the context
        response.set_grad_fn(BackwardContext(backward_fn=self.backward, 
                                             response=response, 
                                             function_purpose=self.function_purpose,
                                             inputs=inputs))

        return response
    
    def backward(self, response: Variable, 
                 function_purpose: str, 
                 inputs: Dict[str, Variable], 
                 backward_engine: EngineLM):
        children_variables = response.predecessors
        if response.get_gradient_text().strip() == "":
            self._backward_through_string_fn_base(children_variables, response, inputs, function_purpose, backward_engine)
        else:
            self._backward_through_string_fn_chain(children_variables, response, inputs, function_purpose, backward_engine)

    @staticmethod
    def backward_static(
        response: Variable,
        function_purpose: str,
        inputs: Dict[str, Variable],
        backward_engine: EngineLM,
    ):
        """
        Static method to perform backward propagation for string-based functions.
        
        This method handles the gradient computation for string-based functions by determining
        whether to use the base backward method (when no gradient exists on the response) or
        the chain backward method (when gradient information is available on the response).
        """
        children_variables = response.predecessors
        if response.get_gradient_text().strip() == "":
            StringBasedFunction._backward_through_string_fn_base(
                children_variables, response, inputs, function_purpose, backward_engine
            )
        else:
            StringBasedFunction._backward_through_string_fn_chain(
                children_variables, response, inputs, function_purpose, backward_engine
            )

    @staticmethod
    def _construct_string_fn_chain_backward_prompt(backward_info: dict[str, str]) -> str:
        conversation = CONVERSATION_TEMPLATE_STRING.format(**backward_info)
        backward_prompt = CONVERSATION_START_INSTRUCTION_STRING_FN_CHAIN.format(conversation=conversation, **backward_info)
        backward_prompt += OBJECTIVE_INSTRUCTION_CHAIN.format(**backward_info)
        backward_prompt += EVALUATE_VARIABLE_INSTRUCTION.format(**backward_info)
        return backward_prompt

    @staticmethod
    def _backward_through_string_fn_chain(variables: List[Variable],
                                          response: Variable,
                                          inputs: Dict[str, Variable],
                                          function_purpose: str, 
                                          backward_engine: EngineLM):
        inputs_string = "\n\n".join([f"**{k.replace('_', ' ').capitalize()}(role: {v.get_role_description()})**: {v.get_short_value()}" for k, v in inputs.items()])
        
        for variable in variables:
            if not variable.requires_grad:
                continue

            backward_info = {
                "response_desc": response.get_role_description(),
                "response_value": response.get_value(),
                "response_gradient": response.get_gradient_text(),
                "function_purpose": function_purpose,
                "inputs_string": inputs_string,
                "variable_desc": variable.get_role_description(),
                "variable_short": variable.get_short_value()
            }
            
            backward_prompt = StringBasedFunction._construct_string_fn_chain_backward_prompt(backward_info)

            logger.info(f"_backward_through_string_fn", extra={"_backward_through_string_fn": backward_prompt})
            gradient_value = backward_engine(backward_prompt, system_prompt=BACKWARD_SYSTEM_PROMPT)
            logger.info(f"_backward_through_string_fn gradient", extra={"_backward_through_string_fn": gradient_value})
            
            var_gradients = Variable(value=gradient_value, role_description=f"feedback to {variable.get_role_description()}")
            variable.gradients.add(var_gradients)
            conversation = CONVERSATION_TEMPLATE_STRING.format(**backward_info)
            variable.gradients_context[var_gradients] = {
                "context": conversation, 
                "response_desc": response.get_role_description(),
                "variable_desc": variable.get_role_description()
            }
            
            if response._reduce_meta:
                var_gradients._reduce_meta.extend(response._reduce_meta)
                variable._reduce_meta.extend(response._reduce_meta)

    @staticmethod
    def _construct_string_fn_base_backward_prompt(backward_info: dict[str, str]) -> str:
        conversation = CONVERSATION_TEMPLATE_STRING.format(**backward_info)
        backward_prompt = CONVERSATION_START_INSTRUCTION_STRING_FN_BASE.format(conversation=conversation, **backward_info)
        backward_prompt += OBJECTIVE_INSTRUCTION_BASE.format(**backward_info)
        backward_prompt += EVALUATE_VARIABLE_INSTRUCTION.format(**backward_info)
        return backward_prompt

    @staticmethod
    def _backward_through_string_fn_base(variables: List[Variable],
                                         response: Variable,
                                         inputs: Dict[str, Variable],
                                         function_purpose: str, 
                                         backward_engine: EngineLM):
        inputs_string = "\n\n".join([f"**{k.replace('_', ' ').capitalize()}(role: {v.get_role_description()})**: {v.get_short_value()}" for k, v in inputs.items()])
        
        for variable in variables:
            if not variable.requires_grad:
                continue

            backward_info = {
                "response_desc": response.get_role_description(),
                "response_value": response.get_value(),
                "response_gradient": response.get_gradient_text(),
                "function_purpose": function_purpose,
                "inputs_string": inputs_string,
                "variable_desc": variable.get_role_description(),
                "variable_short": variable.get_short_value()
            }
            backward_prompt = StringBasedFunction._construct_string_fn_base_backward_prompt(backward_info)
            
            logger.info(f"_backward_through_string_fn prompt", extra={"_backward_through_string_fn": backward_prompt})
            gradient_value = backward_engine(backward_prompt, system_prompt=BACKWARD_SYSTEM_PROMPT)
            logger.info(f"_backward_through_string_fn gradient", extra={"_backward_through_string_fn": gradient_value})

            conversation = CONVERSATION_TEMPLATE_STRING.format(**backward_info)
            var_gradients = Variable(value=gradient_value, role_description=f"feedback to {variable.get_role_description()}")
            variable.gradients.add(var_gradients)
            variable.gradients_context[var_gradients] = {
                "context": conversation, 
                "response_desc": response.get_role_description(),
                "variable_desc": variable.get_role_description()
            }

            if response._reduce_meta:
                var_gradients._reduce_meta.extend(response._reduce_meta)
                variable._reduce_meta.extend(response._reduce_meta)
